{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Simple Regression Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:08.550489Z",
     "iopub.status.busy": "2025-09-28T08:21:08.550262Z",
     "iopub.status.idle": "2025-09-28T08:21:12.185143Z",
     "shell.execute_reply": "2025-09-28T08:21:12.183947Z"
    }
   },
   "outputs": [],
   "source": [
    "from river import (\n",
    "    metrics,\n",
    "    compose,\n",
    "    preprocessing,\n",
    "    datasets,\n",
    "    stats,\n",
    "    feature_extraction,\n",
    ")\n",
    "from deep_river.regression import Regressor\n",
    "from torch import nn\n",
    "from pprint import pprint\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:12.188228Z",
     "iopub.status.busy": "2025-09-28T08:21:12.187876Z",
     "iopub.status.idle": "2025-09-28T08:21:12.195091Z",
     "shell.execute_reply": "2025-09-28T08:21:12.193905Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'clouds': 75,\n",
      " 'description': 'light rain',\n",
      " 'humidity': 81,\n",
      " 'moment': datetime.datetime(2016, 4, 1, 0, 0, 7),\n",
      " 'pressure': 1017.0,\n",
      " 'station': 'metro-canal-du-midi',\n",
      " 'temperature': 6.54,\n",
      " 'wind': 9.3}\n",
      "Number of available bikes: 1\n"
     ]
    }
   ],
   "source": [
    "dataset = datasets.Bikes()\n",
    "\n",
    "# Peek at the first (features, target) pair to see what one sample looks like.\n",
    "x, y = next(iter(dataset))\n",
    "pprint(x)\n",
    "print(f\"Number of available bikes: {y}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:12.198008Z",
     "iopub.status.busy": "2025-09-28T08:21:12.197767Z",
     "iopub.status.idle": "2025-09-28T08:21:12.202609Z",
     "shell.execute_reply": "2025-09-28T08:21:12.201996Z"
    }
   },
   "outputs": [],
   "source": [
    "class MyModule(nn.Module):\n",
    "    def __init__(self, n_features):\n",
    "        super(MyModule, self).__init__()\n",
    "        self.dense0 = nn.Linear(n_features, 5)\n",
    "        self.nonlin = nn.ReLU()\n",
    "        self.dense1 = nn.Linear(5, 1)\n",
    "        self.softmax = nn.Softmax(dim=-1)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = self.nonlin(self.dense0(X))\n",
    "        X = self.nonlin(self.dense1(X))\n",
    "        X = self.softmax(X)\n",
    "        return X\n",
    "\n",
    "\n",
    "def get_hour(x):\n",
    "    x[\"hour\"] = x[\"moment\"].hour\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:12.204995Z",
     "iopub.status.busy": "2025-09-28T08:21:12.204730Z",
     "iopub.status.idle": "2025-09-28T08:21:13.232439Z",
     "shell.execute_reply": "2025-09-28T08:21:13.231681Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><div class=\"river-component river-union\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">['clouds', [...]</pre></summary><code class=\"river-estimator-params\">Select (\n",
       "  clouds\n",
       "  humidity\n",
       "  pressure\n",
       "  temperature\n",
       "  wind\n",
       ")\n",
       "</code></details><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">get_hour</pre></summary><code class=\"river-estimator-params\">\n",
       "def get_hour(x):\n",
       "    x[\"hour\"] = x[\"moment\"].hour\n",
       "    return x\n",
       "\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">y_mean_by_station_and_hour</pre></summary><code class=\"river-estimator-params\">TargetAgg (\n",
       "  by=['station', 'hour']\n",
       "  how=Mean ()\n",
       "  target_name=\"y\"\n",
       ")\n",
       "</code></details></div></div><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">Regressor</pre></summary><code class=\"river-estimator-params\">Regressor (\n",
       "  module=MyModule(\n",
       "  (dense0): Linear(in_features=10, out_features=5, bias=True)\n",
       "  (nonlin): ReLU()\n",
       "  (dense1): Linear(in_features=5, out_features=1, bias=True)\n",
       "  (softmax): Softmax(dim=-1)\n",
       ")\n",
       "  loss_fn=\"mse\"\n",
       "  optimizer_fn=\"sgd\"\n",
       "  lr=0.001\n",
       "  is_feature_incremental=False\n",
       "  device=\"cpu\"\n",
       "  seed=42\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  TransformerUnion (\n",
       "    Select (\n",
       "      clouds\n",
       "      humidity\n",
       "      pressure\n",
       "      temperature\n",
       "      wind\n",
       "    ),\n",
       "    Pipeline (\n",
       "      FuncTransformer (\n",
       "        func=\"get_hour\"\n",
       "      ),\n",
       "      TargetAgg (\n",
       "        by=['station', 'hour']\n",
       "        how=Mean ()\n",
       "        target_name=\"y\"\n",
       "      )\n",
       "    )\n",
       "  ),\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  Regressor (\n",
       "    module=MyModule(\n",
       "    (dense0): Linear(in_features=10, out_features=5, bias=True)\n",
       "    (nonlin): ReLU()\n",
       "    (dense1): Linear(in_features=5, out_features=1, bias=True)\n",
       "    (softmax): Softmax(dim=-1)\n",
       "  )\n",
       "    loss_fn=\"mse\"\n",
       "    optimizer_fn=\"sgd\"\n",
       "    lr=0.001\n",
       "    is_feature_incremental=False\n",
       "    device=\"cpu\"\n",
       "    seed=42\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric = metrics.MAE()\n",
    "\n",
    "# Branch 1: keep only the numeric weather readings.\n",
    "weather_features = compose.Select(\n",
    "    \"clouds\", \"humidity\", \"pressure\", \"temperature\", \"wind\"\n",
    ")\n",
    "# Branch 2: derive the hour, then aggregate the target mean per (station, hour).\n",
    "target_mean_by_station_hour = get_hour | feature_extraction.TargetAgg(\n",
    "    by=[\"station\", \"hour\"], how=stats.Mean()\n",
    ")\n",
    "\n",
    "# Union of both branches -> standard scaling -> neural regressor.\n",
    "model_pipeline = (\n",
    "    (weather_features + target_mean_by_station_hour)\n",
    "    | preprocessing.StandardScaler()\n",
    "    | Regressor(module=MyModule(10), loss_fn=\"mse\", optimizer_fn=\"sgd\")\n",
    ")\n",
    "model_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:13.235389Z",
     "iopub.status.busy": "2025-09-28T08:21:13.235085Z",
     "iopub.status.idle": "2025-09-28T08:21:17.926382Z",
     "shell.execute_reply": "2025-09-28T08:21:17.925370Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "53it [00:00, 525.82it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "157it [00:00, 821.97it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "258it [00:00, 906.00it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "365it [00:00, 969.28it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "474it [00:00, 1010.64it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "585it [00:00, 1041.72it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "695it [00:00, 1059.12it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "806it [00:00, 1072.91it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "914it [00:00, 1074.63it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1028it [00:01, 1091.78it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1147it [00:01, 1118.35it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1259it [00:01, 1106.12it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1370it [00:01, 1066.01it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1479it [00:01, 1072.76it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1587it [00:01, 1073.15it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1698it [00:01, 1083.37it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1808it [00:01, 1087.79it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "1918it [00:01, 1090.43it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2028it [00:01, 1088.89it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2138it [00:02, 1092.02it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2251it [00:02, 1103.27it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2365it [00:02, 1113.66it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2477it [00:02, 1108.26it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2591it [00:02, 1115.86it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2703it [00:02, 1102.41it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2814it [00:02, 1095.06it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "2934it [00:02, 1123.38it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3047it [00:02, 1123.59it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3160it [00:02, 1114.18it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3272it [00:03, 1067.75it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3380it [00:03, 1063.18it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3487it [00:03, 1064.14it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3594it [00:03, 1052.28it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3700it [00:03, 1048.48it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3808it [00:03, 1054.64it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "3916it [00:03, 1059.51it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4023it [00:03, 1045.39it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4129it [00:03, 1048.36it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4239it [00:03, 1061.50it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4350it [00:04, 1074.71it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4458it [00:04, 1072.10it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4566it [00:04, 1073.53it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4675it [00:04, 1077.60it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4784it [00:04, 1079.27it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "4897it [00:04, 1092.86it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "5000it [00:04, 1070.85it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE: 6.83\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Test-then-train: score each sample before the model learns from it.\n",
    "for features, target in tqdm(dataset.take(5000)):\n",
    "    prediction = model_pipeline.predict_one(features)\n",
    "    metric.update(y_true=target, y_pred=prediction)\n",
    "    model_pipeline.learn_one(x=features, y=target)\n",
    "\n",
    "print(f\"MAE: {metric.get():.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep-river",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
